{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "42d33d77",
   "metadata": {},
   "source": [
    "# Exploring Kaputt Dataset\n",
    "The **Kaputt** dataset marks a new era for visual defect detection in **Retail Logistics**. Developed by researchers from **Amazon** and the **University of Oxford**, Kaputt contains more than **238,000 images** and **48,000 unique products**, including **over 29,000 defective instances**. It is **40x larger** than existing benchmarks such as MVTec-AD and VisA.\n",
    "\n",
    "While previous datasets focused on tightly controlled manufacturing settings, Kaputt introduces real-world complexity, products with varying shapes, materials, lighting conditions, and poses. This makes defect detection far more challenging and realistic: even state-of-the-art models \n",
    "struggle to exceed **56.9% AUROC** on this benchmark.\n",
    "\n",
    "![kaputt_overview](https://cdn.voxel51.com/kaputt_overview.webp)\n",
    "\n",
    "Dataset: [kaputt-dataset.com](https://www.kaputt-dataset.com)  \n",
    "Reference: *Kaputt: A Large-Scale Benchmark for Visual Defect Detection in Retail Logistics* (ICCV 2025)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93b3722e",
   "metadata": {},
   "source": [
    "**So, what’s the takeaway?**\n",
    "\n",
    "1. Load and explore the **Kaputt dataset** interactively using FiftyOne.  \n",
    "2. Visualize its structure and metadata fields.  \n",
    "3. Compute and analyze embeddings, find similar samples query and filter the dataset from different inputs.  \n",
    "4. Experiment with VLMs as **FastVLM** and other models from the FiftyOne Model Zoo.  \n",
    "5. Demonstrate how FiftyOne helps uncover valuable insights into data quality, bias, and model performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebd78ea4",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "If you haven’t already, install the required packages. These cells will only install packages if they’re missing. This notebook was tested in a Python Env (py 3.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeefee84",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install fiftyone umap-learn torch torchvision pandas pyarrow"
   ]
  },
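  {
   "cell_type": "markdown",
   "id": "a1f0c001",
   "metadata": {},
   "source": [
    "The install command above checks every package on each run. As a minimal sketch, you can first detect which packages are missing with `importlib.util.find_spec` and install only those. The `PIP_TO_MODULE` mapping from pip names to import names is an assumption written for this notebook's package list, and `missing_packages` is a hypothetical helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c002",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib.util\n",
    "\n",
    "# Assumed mapping from pip package names to their top-level import names\n",
    "PIP_TO_MODULE = {\n",
    "    'fiftyone': 'fiftyone',\n",
    "    'umap-learn': 'umap',\n",
    "    'torch': 'torch',\n",
    "    'torchvision': 'torchvision',\n",
    "    'pandas': 'pandas',\n",
    "    'pyarrow': 'pyarrow',\n",
    "}\n",
    "\n",
    "def missing_packages(pip_to_module):\n",
    "    # Return the pip names whose top-level import module cannot be found\n",
    "    return [\n",
    "        pip_name\n",
    "        for pip_name, module in pip_to_module.items()\n",
    "        if importlib.util.find_spec(module) is None\n",
    "    ]\n",
    "\n",
    "missing = missing_packages(PIP_TO_MODULE)\n",
    "if missing:\n",
    "    print('Missing packages; run: %pip install ' + ' '.join(missing))\n",
    "else:\n",
    "    print('All required packages are already installed.')"
   ]
  },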
  {
   "cell_type": "markdown",
   "id": "405487b7",
   "metadata": {},
   "source": [
    "### Request access and download the dataset locally and dataset structure\n",
    "\n",
    "1. **Fill the Request [Form](https://www.kaputt-dataset.com/)**\n",
    "\n",
    "   - Locate the access form for the Kaputt dataset, available on the dataset’s official website or publication page.  \n",
    "   - Complete all required fields with accurate information such as your **name**, **email**, and **affiliation**.\n",
    "\n",
    "2. **Submit the Form**\n",
    "\n",
    "   - Double-check your details before submission.  \n",
    "   - Click **Submit** to send your access request.\n",
    "\n",
    "3. **Check Your Email**\n",
    "\n",
    "   - Open the inbox of the email address you provided.  \n",
    "   - Look for a confirmation or dataset access message.\n",
    "\n",
    "#### Dataset structure\n",
    "\n",
    "In this notebook we will explore the complete query folder, for the other folders you can extend with the provided code.\n",
    "\n",
    "```\n",
    "kaputt/\n",
    "├── datasets/                         # Parquet metadata (index tables)\n",
    "│   ├── query-train.parquet\n",
    "│   ├── query-validation.parquet\n",
    "│   ├── query-test.parquet\n",
    "│   ├── reference-train.parquet\n",
    "│   ├── reference-validation.parquet\n",
    "│   ├── reference-test.parquet\n",
    "│   └── README.md\n",
    "│\n",
    "├── query-image/                      # Full query images (main inputs)\n",
    "│   └── data/\n",
    "│       ├── train/\n",
    "│       │   └── query-data/image/\n",
    "│       │       ├── <capture_id>.jpg\n",
    "│       │       ├── ...\n",
    "│       ├── validation/\n",
    "│       │   └── query-data/image/\n",
    "│       │       ├── <capture_id>.jpg\n",
    "│       │       ├── ...\n",
    "│       └── test/\n",
    "│           └── query-data/image/\n",
    "│               ├── <capture_id>.jpg\n",
    "│               ├── ...\n",
    "│\n",
    "├── query-crop/                       # Cropped item regions\n",
    "│   └── data/\n",
    "│       ├── train/\n",
    "│       │   └── query-data/crop/\n",
    "│       │       ├── <capture_id>.jpg\n",
    "│       │       ├── ...\n",
    "│       ├── validation/\n",
    "│       │   └── query-data/crop/\n",
    "│       │       ├── <capture_id>.jpg\n",
    "│       │       ├── ...\n",
    "│       └── test/\n",
    "│           └── query-data/crop/\n",
    "│               ├── <capture_id>.jpg\n",
    "│               ├── ...\n",
    "│\n",
    "├── query-mask/                       # Binary/segmentation masks\n",
    "│   └── data/\n",
    "│       ├── train/\n",
    "│       │   └── query-data/mask/\n",
    "│       │       ├── <capture_id>.png\n",
    "│       │       ├── ...\n",
    "│       ├── validation/\n",
    "│       │   └── query-data/mask/\n",
    "│       │       ├── <capture_id>.png\n",
    "│       │       ├── ...\n",
    "│       └── test/\n",
    "│           └── query-data/mask/\n",
    "│               ├── <capture_id>.png\n",
    "│               ├── ...\n",
    "│\n",
    "├── reference-image/                  # Reference (non-defective) images\n",
    "│   └── data/\n",
    "│       ├── train/reference-data/image/\n",
    "│       ├── validation/reference-data/image/\n",
    "│       └── test/reference-data/image/\n",
    "│\n",
    "├── reference-crop/                   # Crops for reference images\n",
    "│   └── data/\n",
    "│       ├── train/reference-data/crop/\n",
    "│       ├── validation/reference-data/crop/\n",
    "│       └── test/reference-data/crop/\n",
    "│\n",
    "├── reference-mask/                   # Segmentation masks for reference images\n",
    "│   └── data/\n",
    "│       ├── train/reference-data/mask/\n",
    "│       ├── validation/reference-data/mask/\n",
    "│       └── test/reference-data/mask/\n",
    "│\n",
    "├── sample-data/                      # Small subset for testing\n",
    "│   ├── data/\n",
    "│   │   └── train/\n",
    "│   │       ├── query-data/\n",
    "│   │       │   ├── image/\n",
    "│   │       │   ├── crop/\n",
    "│   │       │   └── mask/\n",
    "│   │       └── reference-data/\n",
    "│   │           ├── image/\n",
    "│   │           ├── crop/\n",
    "│   │           └── mask/\n",
    "│   ├── query-sample.parquet\n",
    "│   └── reference-sample.parquet\n",
    "│\n",
    "└── kaputt-release/                   # Original release version (mirrors structure above)\n",
    "    ├── train/\n",
    "    │   ├── query-data/\n",
    "    │   └── reference-data/\n",
    "    ├── validation/\n",
    "    │   ├── query-data/\n",
    "    │   └── reference-data/\n",
    "    └── test/\n",
    "        ├── query-data/\n",
    "        └── reference-data/\n",
    "```"
   ]
  },
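  {
   "cell_type": "markdown",
   "id": "a1f0c003",
   "metadata": {},
   "source": [
    "To make the layout concrete, here is a minimal sketch of how one capture maps to its query image, crop, and mask paths. It builds the filenames from `<capture_id>` as shown in the tree; the root `/data/kaputt` and the capture ID `example_capture` are hypothetical placeholders, and `query_media_paths` is a hypothetical helper. (The import code in the next section takes the filenames from the Parquet columns instead.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c004",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import PurePosixPath\n",
    "\n",
    "# kind -> (top-level media folder, terminal folder under query-data), per the tree above\n",
    "LAYOUT = {\n",
    "    'image': ('query-image', 'image'),\n",
    "    'crop': ('query-crop', 'crop'),\n",
    "    'mask': ('query-mask', 'mask'),\n",
    "}\n",
    "\n",
    "def query_media_paths(root, split, capture_id):\n",
    "    # Images and crops are .jpg, masks are .png (per the tree above)\n",
    "    paths = {}\n",
    "    for kind, (media_folder, terminal) in LAYOUT.items():\n",
    "        ext = '.png' if kind == 'mask' else '.jpg'\n",
    "        paths[kind] = str(\n",
    "            PurePosixPath(root) / media_folder / 'data' / split\n",
    "            / 'query-data' / terminal / (capture_id + ext)\n",
    "        )\n",
    "    return paths\n",
    "\n",
    "# Hypothetical root and capture ID, for illustration only\n",
    "for kind, path in query_media_paths('/data/kaputt', 'train', 'example_capture').items():\n",
    "    print(kind, '->', path)"
   ]
  },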
  {
   "cell_type": "markdown",
   "id": "c6011528",
   "metadata": {},
   "source": [
    "### Import Kaputt (Query Only)\n",
    "\n",
    "This cell imports the query portion of the Kaputt dataset into FiftyOne. It reads the query Parquet files from ```/datasets/```, builds absolute paths for images, crops, and masks, and creates a FiftyOne dataset with fields for defect attributes and item metadata.\n",
    "Only train and validation splits are loaded for faster testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa364e7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import_kaputt_to_fiftyone_queries_only.py\n",
    "#\n",
    "# Media layout:\n",
    "# <ROOT>/query-image/data/<split>/query-data/image/<filename>\n",
    "# <ROOT>/query-crop/data/<split>/query-data/crop/<filename>\n",
    "# <ROOT>/query-mask/data/<split>/query-data/mask/<filename>\n",
    "#\n",
    "# Parquets live in <ROOT>/datasets/ with names like:\n",
    "# query-train.parquet, query_validation.parquet, query-train.parquet.gz, etc.\n",
    "#\n",
    "# Loads ONLY query media (+ attrs) for splits: train, validation\n",
    "\n",
    "import os\n",
    "from pathlib import Path\n",
    "import glob\n",
    "import pandas as pd\n",
    "import fiftyone as fo\n",
    "\n",
    "# ---------- CONFIG ----------\n",
    "EXTERNAL_ROOT = \"/your/main/folder/for/kaputt\"   # <-- your kaputt root\n",
    "DATASET_NAME  = \"kaputt\"\n",
    "PARQUET_DIR   = \"datasets\"\n",
    "SPLITS = (\"train\", \"validation\")  # <-- test excluded\n",
    "NUM_WORKERS_METADATA = 8\n",
    "# ---------------------------\n",
    "\n",
    "root = Path(EXTERNAL_ROOT).resolve()\n",
    "pdir = (root / PARQUET_DIR).resolve()\n",
    "\n",
    "# kind -> (top media folder, terminal folder under query-data)\n",
    "# Final path: <root>/<media_folder>/data/<split>/query-data/<terminal>/<filename>\n",
    "MEDIA_MAPPING = {\n",
    "    \"query_image\": (\"query-image\", \"image\"),\n",
    "    \"query_crop\":  (\"query-crop\",  \"crop\"),\n",
    "    \"query_mask\":  (\"query-mask\",  \"mask\"),\n",
    "}\n",
    "\n",
    "def _basename(pathlike) -> str | None:\n",
    "    if pathlike is None or (isinstance(pathlike, float) and pd.isna(pathlike)):\n",
    "        return None\n",
    "    s = str(pathlike).replace(\"\\\\\", \"/\").rstrip(\"/\")\n",
    "    base = os.path.basename(s)\n",
    "    return base or None\n",
    "\n",
    "def resolve_path_from_schema(rel, split: str, kind: str) -> str | None:\n",
    "    \"\"\"\n",
    "    Build absolute path:\n",
    "        <root>/<media_folder>/data/<split>/query-data/<terminal>/<filename>\n",
    "    rel may be absolute, relative, or just a filename.\n",
    "    \"\"\"\n",
    "    if rel is None or (isinstance(rel, float) and pd.isna(rel)):\n",
    "        return None\n",
    "\n",
    "    if os.path.isabs(str(rel)) and os.path.exists(str(rel)):\n",
    "        return str(Path(rel).resolve())\n",
    "\n",
    "    fname = _basename(rel)\n",
    "    if not fname:\n",
    "        return None\n",
    "\n",
    "    media_folder, terminal = MEDIA_MAPPING[kind]\n",
    "    abs_path = root / media_folder / \"data\" / split / \"query-data\" / terminal / fname\n",
    "    return str(abs_path)\n",
    "\n",
    "def find_query_parquet(split: str) -> Path:\n",
    "    \"\"\"\n",
    "    Robustly find a parquet file for a split inside <root>/datasets/.\n",
    "    Supports common variants and .parquet.gz.\n",
    "    \"\"\"\n",
    "    candidates = []\n",
    "    patterns = [\n",
    "        f\"query-{split}.parquet\",\n",
    "        f\"query_{split}.parquet\",\n",
    "        f\"query-*{split}*.parquet\",\n",
    "        f\"query-{split}.parquet.gz\",\n",
    "        f\"query_{split}.parquet.gz\",\n",
    "        f\"query-*{split}*.parquet.gz\",\n",
    "    ]\n",
    "    for pat in patterns:\n",
    "        candidates.extend(glob.glob(str(pdir / pat)))\n",
    "\n",
    "    candidates = sorted(set(candidates))\n",
    "    if not candidates:\n",
    "        raise FileNotFoundError(\n",
    "            f\"No query parquet found for split '{split}'. \"\n",
    "            f\"Tried under: {pdir}\\nPatterns: {patterns}\"\n",
    "        )\n",
    "\n",
    "    preferred = pdir / f\"query-{split}.parquet\"\n",
    "    if preferred.exists():\n",
    "        return preferred\n",
    "\n",
    "    return Path(candidates[0])\n",
    "\n",
    "def read_parquet_robust(fp: Path) -> pd.DataFrame:\n",
    "    try:\n",
    "        return pd.read_parquet(fp)\n",
    "    except Exception as e:\n",
    "        raise RuntimeError(\n",
    "            f\"Failed to read parquet file: {fp}\\n\"\n",
    "            f\"Error: {type(e).__name__}: {e}\\n\"\n",
    "            f\"Tip: ensure it's a valid parquet (or parquet.gz) file.\"\n",
    "        ) from e\n",
    "\n",
    "def prepare_split(split: str):\n",
    "    try:\n",
    "        q_fp = find_query_parquet(split)\n",
    "    except FileNotFoundError as e:\n",
    "        print(f\"[INFO] Skipping split '{split}': {e}\")\n",
    "        return []\n",
    "\n",
    "    print(f\"[INFO] Using query parquet for '{split}': {q_fp}\")\n",
    "    q = read_parquet_robust(q_fp)\n",
    "\n",
    "    # Normalize optional columns\n",
    "    for col in (\"defect\", \"major_defect\", \"defect_types\"):\n",
    "        if col not in q.columns:\n",
    "            q[col] = None\n",
    "\n",
    "    samples, skipped_missing = [], 0\n",
    "\n",
    "    for _, row in q.iterrows():\n",
    "        query_image = resolve_path_from_schema(row.get(\"query_image\"), split, \"query_image\")\n",
    "        query_crop  = resolve_path_from_schema(row.get(\"query_crop\"),  split, \"query_crop\")\n",
    "        query_mask  = resolve_path_from_schema(row.get(\"query_mask\"),  split, \"query_mask\")\n",
    "\n",
    "        # Require main media\n",
    "        if not (query_image and os.path.exists(query_image)):\n",
    "            skipped_missing += 1\n",
    "            continue\n",
    "\n",
    "        # defect_types -> list\n",
    "        dtypes = row.get(\"defect_types\", None)\n",
    "        if isinstance(dtypes, str):\n",
    "            dtypes = [s.strip() for s in dtypes.split(\",\") if s.strip()]\n",
    "        elif dtypes is None or (isinstance(dtypes, float) and pd.isna(dtypes)):\n",
    "            dtypes = []\n",
    "\n",
    "        s = fo.Sample(filepath=query_image, tags=[split])\n",
    "        if query_crop and os.path.exists(query_crop):\n",
    "            s[\"query_crop\"] = query_crop\n",
    "        if query_mask and os.path.exists(query_mask):\n",
    "            s[\"query_mask\"] = query_mask\n",
    "\n",
    "        s[\"defect\"]          = bool(row.get(\"defect\", False))\n",
    "        s[\"major_defect\"]    = bool(row.get(\"major_defect\", False))\n",
    "        s[\"defect_types\"]    = dtypes\n",
    "        s[\"item_material\"]   = row.get(\"item_material\", None)\n",
    "        s[\"capture_id\"]      = row.get(\"capture_id\", None)\n",
    "        s[\"item_identifier\"] = row.get(\"item_identifier\", None)\n",
    "        s[\"split\"]           = split\n",
    "\n",
    "        samples.append(s)\n",
    "\n",
    "    if skipped_missing:\n",
    "        print(f\"[WARN] Split '{split}': skipped {skipped_missing} samples with missing query_image\")\n",
    "\n",
    "    return samples"
   ]
  },
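  {
   "cell_type": "markdown",
   "id": "a1f0c005",
   "metadata": {},
   "source": [
    "`find_query_parquet` tries several glob patterns and prefers the exact `query-<split>.parquet` name. The sketch below reproduces that selection logic with `fnmatch` on a hypothetical list of filenames, so you can check which file would be picked for each split without touching the disk. `pick_parquet` is a hypothetical helper, not part of the import code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c006",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fnmatch import fnmatch\n",
    "\n",
    "def pick_parquet(filenames, split):\n",
    "    # Same patterns as find_query_parquet\n",
    "    patterns = [\n",
    "        f'query-{split}.parquet',\n",
    "        f'query_{split}.parquet',\n",
    "        f'query-*{split}*.parquet',\n",
    "        f'query-{split}.parquet.gz',\n",
    "        f'query_{split}.parquet.gz',\n",
    "        f'query-*{split}*.parquet.gz',\n",
    "    ]\n",
    "    candidates = sorted({n for n in filenames for p in patterns if fnmatch(n, p)})\n",
    "    if not candidates:\n",
    "        return None\n",
    "    preferred = f'query-{split}.parquet'\n",
    "    # Prefer the canonical name, otherwise fall back to the first match alphabetically\n",
    "    return preferred if preferred in candidates else candidates[0]\n",
    "\n",
    "# Hypothetical directory listing\n",
    "names = ['query-train.parquet', 'query_validation.parquet', 'reference-train.parquet']\n",
    "for split_name in ('train', 'validation', 'test'):\n",
    "    print(split_name, '->', pick_parquet(names, split_name))"
   ]
  },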
  {
   "cell_type": "markdown",
   "id": "23a640ec",
   "metadata": {},
   "source": [
    "## Create and Save FiftyOne Dataset\n",
    "\n",
    "This cell creates a new FiftyOne dataset named after DATASET_NAME, adds samples for each available split (train and validation), and computes image metadata (dimensions, channels, etc.). If a dataset with the same name already exists, it is replaced to ensure a clean import.\n",
    "\n",
    "**After adding all samples:**\n",
    "\n",
    "- The dataset is saved and marked as persistent.\n",
    "- Split counts are printed for verification.\n",
    "- Metadata is computed in parallel using ```NUM_WORKERS_METADATA```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e70ea73",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "if fo.dataset_exists(DATASET_NAME):\n",
    "    fo.delete_dataset(DATASET_NAME)\n",
    "ds = fo.Dataset(DATASET_NAME)\n",
    "\n",
    "total_added = 0\n",
    "for split in SPLITS:\n",
    "    samples = prepare_split(split)\n",
    "    if samples:\n",
    "        ds.add_samples(samples)\n",
    "        print(f\"[INFO] Added {len(samples)} samples for split '{split}'\")\n",
    "        total_added += len(samples)\n",
    "\n",
    "ds.save()\n",
    "print(ds)\n",
    "try:\n",
    "    print(\"Counts by split:\", ds.count_values(\"split\"))\n",
    "except Exception as e:\n",
    "    print(\"[INFO] Could not aggregate counts by split:\", e)\n",
    "\n",
    "if len(ds) > 0:\n",
    "    try:\n",
    "        print(\"[INFO] Computing metadata...\")\n",
    "        ds.compute_metadata(overwrite=True, num_workers=NUM_WORKERS_METADATA)\n",
    "        print(\"[INFO] Done computing metadata.\")\n",
    "    except Exception as e:\n",
    "        print(f\"[WARN] Failed to compute metadata: {e}\")\n",
    "else:\n",
    "    print(\"[WARN] Dataset is empty; skipped metadata computation\")\n",
    "\n",
    "ds.persistent = True\n",
    "print(f\"[INFO] Dataset '{DATASET_NAME}' is ready with {total_added} samples.\")"
   ]
  },
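  {
   "cell_type": "markdown",
   "id": "a1f0c007",
   "metadata": {},
   "source": [
    "`count_values(\"split\")` reports how many samples each split contains, but class balance matters just as much for defect detection. Here is a plain-Python sketch of a per-split defect rate; `defect_rate_by_split` is a hypothetical helper and the records are made-up stand-ins for `(split, defect)` pairs. With the real dataset, you could pass `zip(*ds.values([\"split\", \"defect\"]))` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c008",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def defect_rate_by_split(records):\n",
    "    # records: iterable of (split, defect) pairs, where defect is a bool\n",
    "    totals, defects = Counter(), Counter()\n",
    "    for split_name, is_defect in records:\n",
    "        totals[split_name] += 1\n",
    "        defects[split_name] += int(bool(is_defect))\n",
    "    return {name: defects[name] / totals[name] for name in totals}\n",
    "\n",
    "# Hypothetical records, for illustration only\n",
    "records = [('train', True), ('train', False), ('train', False), ('validation', True)]\n",
    "for split_name, rate in defect_rate_by_split(records).items():\n",
    "    print(f'{split_name}: {rate:.1%} defective')"
   ]
  },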
  {
   "cell_type": "markdown",
   "id": "f993effe",
   "metadata": {},
   "source": [
    "### Compute CLIP Embeddings and Similarity Index\n",
    "\n",
    "This cell uses the ```CLIP ViT-B/32``` model from the **FiftyOne Model Zoo** to generate visual embeddings for all samples in the dataset.\n",
    "The embeddings are stored in the field ```clip_embeddings``` and used to build a similarity index (```clip_sim```) via the **FiftyOne Brain** module.\n",
    "\n",
    "This enables efficient image similarity search, semantic clustering, and content-based exploration directly within FiftyOne. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6a8aabf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import fiftyone.zoo as foz\n",
    "import fiftyone.brain as fob\n",
    "\n",
    "# Load CLIP model\n",
    "model = foz.load_zoo_model(\"clip-vit-base32-torch\")\n",
    "\n",
    "# Compute embeddings for all samples\n",
    "ds.compute_embeddings(model, embeddings_field=\"clip_embeddings\")\n",
    "\n",
    "# Create similarity index from pre-computed embeddings\n",
    "fob.compute_similarity(\n",
    "    ds,\n",
    "    model=\"clip-vit-base32-torch\",\n",
    "    embeddings=\"clip_embeddings\",\n",
    "    brain_key=\"clip_sim\",\n",
    ")"
   ]
  },
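  {
   "cell_type": "markdown",
   "id": "a1f0c009",
   "metadata": {},
   "source": [
    "Conceptually, a similarity index answers \"which samples have embeddings closest to this one?\". The brute-force sketch below ranks toy 3-D vectors by cosine similarity. It illustrates the idea only and is not how the FiftyOne Brain backend is implemented; real CLIP ViT-B/32 embeddings have 512 dimensions, and the sample IDs here are hypothetical. On the real index, the equivalent query is `ds.sort_by_similarity(query, k=10, brain_key=\"clip_sim\")`, where `query` is a sample ID or, because the index was built with a CLIP model, a text prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c010",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def cosine(a, b):\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))\n",
    "    return dot / norm if norm else 0.0\n",
    "\n",
    "def top_k_similar(query, embeddings, k=2):\n",
    "    # Score every stored embedding against the query (brute force), highest first\n",
    "    scores = [(sample_id, cosine(query, vec)) for sample_id, vec in embeddings.items()]\n",
    "    return sorted(scores, key=lambda item: item[1], reverse=True)[:k]\n",
    "\n",
    "# Toy embeddings keyed by hypothetical sample IDs\n",
    "toy_embeddings = {\n",
    "    'box_a': [0.9, 0.1, 0.0],\n",
    "    'box_b': [0.8, 0.2, 0.1],\n",
    "    'book_c': [0.0, 0.2, 0.9],\n",
    "}\n",
    "print(top_k_similar([1.0, 0.0, 0.0], toy_embeddings, k=2))"
   ]
  },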
  {
   "cell_type": "markdown",
   "id": "d43bed62",
   "metadata": {},
   "source": [
    "### Dimensionality Reduction with UMAP\n",
    "\n",
    "This cell applies **UMAP** ```(Uniform Manifold Approximation and Projection)``` to the CLIP embeddings stored in ```clip_embeddings```.\n",
    "It computes a 2D visualization of the dataset’s feature space using the FiftyOne Brain module and saves it under the key ```clip_vis```.\n",
    "\n",
    "This allows interactive exploration of the dataset in embedding space, revealing visual clusters and relationships between samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3126cd0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality reduction using UMAP on the embeddings\n",
    "fob.compute_visualization(\n",
    "    ds,\n",
    "    embeddings=\"clip_embeddings\",\n",
    "    method=\"umap\",\n",
    "    brain_key=\"clip_vis\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ee0d365",
   "metadata": {},
   "source": [
    "### Create Index for Faster Filtering\n",
    "\n",
    "This cell creates a compound index on the fields ```defect_types```, ```item_material```, and ```split``` to optimize query performance in FiftyOne. By indexing these commonly filtered fields, dataset operations such as searching, filtering, and aggregating by defect category or material type become significantly faster, especially when working with **large datasets**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7183a64c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an index to speed up filtering by defect types, item material, and split\n",
    "ds.create_index(\n",
    "    [(\"defect_types\", 1), (\"item_material\", 1), (\"split\", 1)]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e038d698",
   "metadata": {},
   "source": [
    "Open a web browser session to play interactively with your dataset, metadata and embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bf41133",
   "metadata": {},
   "source": [
    "### Explore visually in the FiftyOne App"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c94a3ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch FiftyOne App\n",
    "session = fo.launch_app(ds, port=5151, auto=False)\n",
    "session.wait()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22978916",
   "metadata": {},
   "source": [
    "## Apply FastVLM to Evaluate Defect Severity\n",
    "\n",
    "This cells integrate the ```FastVLM``` model from the **open-source community** into FiftyOne to analyze the Kaputt dataset.\n",
    "It registers the model source from GitHub, downloads the desired FastVLM variant (```0.5B```, ```1.5B```, or ```7B```), and loads it into the environment.\n",
    "\n",
    "Using a detailed prompt adapted from the model authors, the system asks the model to reason about item condition and defect severity.\n",
    "The model’s output is stored in a new field called result, containing structured JSON with:\n",
    "\n",
    "```\"condition\"``` → ```\"DAMAGED\"``` or ```\"UNDAMAGED\"```\n",
    "\n",
    "```\"severity\"``` → Numeric score from 0 (```pristine```) to 10 (```completely destroyed```)\n",
    "\n",
    "This enables analysis of how a large vision-language model interprets real-world packaging damage across the **Kaputt dataset**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf463d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import fiftyone.zoo as foz\n",
    "\n",
    "# Register the model source\n",
    "foz.register_zoo_model_source(\n",
    "    \"https://github.com/harpreetsahota204/fast_vlm\",\n",
    "    overwrite=True\n",
    ")\n",
    "\n",
    "# Download the desired model variant (first time only)\n",
    "# Choose from: \"apple/FastVLM-0.5B\", \"apple/FastVLM-1.5B\", or \"apple/FastVLM-7B\"\n",
    "foz.download_zoo_model(\n",
    "    \"https://github.com/harpreetsahota204/fast_vlm\",\n",
    "    model_name=\"apple/FastVLM-7B\"  # Change to desired model variant\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca5ecdb0",
   "metadata": {},
   "source": [
    "### Add predictions to your dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e011ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model\n",
    "model = foz.load_zoo_model(\"apple/FastVLM-1.5B\")\n",
    "\n",
    "# Generate creative content with a custom prompt\n",
    "model.prompt = (\"You are a highly skilled subject matter expert \"\n",
    "    \"for inventory quality assurance and control. The presented \"\n",
    "    \"image shows an item inside a tray. You have to determine whether \"\n",
    "    \"the item is in pristine condition and can be sold as new and shipped \"\n",
    "    \"to the customer as is, or whether it is damaged in any way and needs \"\n",
    "    \"further attention before it can be shipped. Consider the \"\n",
    "    \"following damage categories: crushed, tear, hole, deformed, ripped, \"\n",
    "    \"deconstructed. Typical defects also include open boxes, or damaged and ripped packaging. \"\n",
    "    \"Sometimes if the packaging is damaged, the item itself may become deconstructed and parts \" \n",
    "    \"of the content may fall out. The container itself may be dirty \"\n",
    "    \"which should not count as damage. However, if there is \"\n",
    "    \"spillage that originated from a liquid item, then it must \"\n",
    "    \"be called out as damage. Pay close attention to books \"\n",
    "    \"and especially to corners of front or back pages. Moreover, \"\n",
    "    \"items that a deconstructed, i.e. where the original \"\n",
    "    \"packaging is damaged or fell off, should be flagged as \"\n",
    "    \"damaged. In addition to the final decision specified by \"\n",
    "    \"DAMAGED or UNDAMAGED, please also provide \"\n",
    "    \"the severity score on a scale from 0 (pristine condition) \"\n",
    "    \"to 10 (completely destroyed). Think step-by-step and \"\n",
    "    \"provide the final response as json with keys 'condition' \"\n",
    "    \"and 'severity'.\"\n",
    ")\n",
    "ds.apply_model(model, label_field=\"result\")"
   ]
  },
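  {
   "cell_type": "markdown",
   "id": "a1f0c011",
   "metadata": {},
   "source": [
    "The prompt asks for JSON, but a VLM may wrap it in reasoning text, return invalid JSON, or omit a key. As a minimal sketch, assuming the raw model output is available as a plain string (how the FastVLM integration stores it in `result` is not verified here), the hypothetical `parse_assessment` helper below extracts the span from the first `{` to the last `}` and validates `condition` and `severity`, returning `None` when the output is unusable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c012",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "VALID_CONDITIONS = {'DAMAGED', 'UNDAMAGED'}\n",
    "\n",
    "def parse_assessment(text):\n",
    "    # Return (condition, severity), or None if no valid JSON answer is found\n",
    "    start, end = text.find('{'), text.rfind('}')\n",
    "    if start == -1 or end <= start:\n",
    "        return None\n",
    "    try:\n",
    "        data = json.loads(text[start:end + 1])\n",
    "    except json.JSONDecodeError:\n",
    "        return None\n",
    "    condition = str(data.get('condition', '')).strip().upper()\n",
    "    try:\n",
    "        severity = float(data.get('severity'))\n",
    "    except (TypeError, ValueError):\n",
    "        return None\n",
    "    # Reject unknown labels and scores outside the 0-10 scale used in the prompt\n",
    "    if condition not in VALID_CONDITIONS or not 0 <= severity <= 10:\n",
    "        return None\n",
    "    return condition, severity\n",
    "\n",
    "# Hypothetical model output with reasoning before the JSON answer\n",
    "raw = 'The corner of the box is crushed. {\"condition\": \"DAMAGED\", \"severity\": 6}'\n",
    "print(parse_assessment(raw))"
   ]
  },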
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a7c2ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "session = fo.launch_app(ds, port=5151, auto=False)"
   ]
  },
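  {
   "cell_type": "markdown",
   "id": "a1f0c013",
   "metadata": {},
   "source": [
    "Since every query sample carries a ground-truth `defect` flag, parsed predictions can be scored directly. The plain-Python sketch below computes precision and recall for the DAMAGED class from hypothetical `(predicted_condition, defect)` pairs, counting unparseable outputs separately instead of dropping them. Treating `defect=True` as the positive class is an assumption (`major_defect` is a reasonable alternative), `score_predictions` is a hypothetical helper, and these fixed-decision metrics are not the AUROC reported for the benchmark."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c014",
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_predictions(pairs):\n",
    "    # pairs: (predicted condition or None, ground-truth defect flag)\n",
    "    tp = fp = fn = tn = unparsed = 0\n",
    "    for predicted, is_defect in pairs:\n",
    "        if predicted is None:\n",
    "            unparsed += 1\n",
    "        elif predicted == 'DAMAGED':\n",
    "            if is_defect:\n",
    "                tp += 1\n",
    "            else:\n",
    "                fp += 1\n",
    "        elif is_defect:\n",
    "            fn += 1\n",
    "        else:\n",
    "            tn += 1\n",
    "    precision = tp / (tp + fp) if tp + fp else 0.0\n",
    "    recall = tp / (tp + fn) if tp + fn else 0.0\n",
    "    return {'precision': precision, 'recall': recall, 'tp': tp, 'fp': fp,\n",
    "            'fn': fn, 'tn': tn, 'unparsed': unparsed}\n",
    "\n",
    "# Hypothetical predictions paired with ground-truth defect flags\n",
    "pairs = [('DAMAGED', True), ('DAMAGED', False), ('UNDAMAGED', True),\n",
    "         ('UNDAMAGED', False), (None, True)]\n",
    "print(score_predictions(pairs))"
   ]
  },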
  {
   "cell_type": "markdown",
   "id": "eee2961a",
   "metadata": {},
   "source": [
    "### Creating a Grouped Dataset with Images, Crops, and Masks in FiftyOne\n",
    "\n",
    "This code demonstrates how to create a **grouped dataset** in FiftyOne, where each group can contain up to three related slices: the original image, a crop, and a mask. This structure is useful for organizing and visualizing multimodal or multiview data, such as associating each image with its corresponding crop and segmentation mask.\n",
    "\n",
    "**Key steps:**\n",
    "\n",
    "1. **Dataset Setup**:  \n",
    "   - Checks if a grouped dataset with the specified name exists and deletes it if so.\n",
    "   - Creates a new grouped dataset and adds a group field (with `\"image\"` as the default slice).\n",
    "\n",
    "2. **Sample Processing**:  \n",
    "   - Iterates through each sample in the original dataset.\n",
    "   - Copies relevant fields (excluding system fields) to new grouped samples.\n",
    "\n",
    "3. **Image Slice**:  \n",
    "   - Always adds the original image as a slice in the group.\n",
    "\n",
    "4. **Crop Slice**:  \n",
    "   - If a crop is available (`sample.query_crop`), adds it as a separate slice.\n",
    "\n",
    "5. **Mask Slice**:  \n",
    "   - If a mask is available (`sample.query_mask`), reads and normalizes it to the 0-255 range.\n",
    "   - Applies a binary threshold to create a binarized mask.\n",
    "   - Only adds the mask slice if it contains meaningful (non-black) data.\n",
    "\n",
    "6. **Debugging and Statistics**:  \n",
    "   - Tracks and prints the number of total samples processed, masks added, and masks skipped.\n",
    "   - Prints a summary of the grouped dataset, including the number of mask samples.\n",
    "\n",
    "This approach leverages FiftyOne's native grouping feature, which is ideal for paired or multimodal data exploration and visualization. For more details on grouped datasets and their use cases, see the [FiftyOne grouped datasets documentation](https://docs.voxel51.com/user_guide/groups.html) and related [example notebooks](https://github.com/voxel51/fiftyone-examples/blob/master/examples/Grouped%20Datasets.ipynb).\n",
    "\n",
    "> _Grouped datasets in FiftyOne allow you to organize related samples (such as images, masks, and crops) under a common group, enabling synchronized visualization and analysis across different data modalities or views._  \n",
    "> [FiftyOne User Guide: Grouped Datasets](https://docs.voxel51.com/user_guide/groups.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8937b588",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "\n",
    "# Delete and recreate the grouped dataset\n",
    "grouped_dataset_name = \"kaputt_grouped\"\n",
    "if grouped_dataset_name in fo.list_datasets():\n",
    "    fo.delete_dataset(grouped_dataset_name)\n",
    "grouped_dataset = fo.Dataset(grouped_dataset_name, persistent=True, overwrite=True)\n",
    "grouped_dataset.add_group_field(\"group\", default=\"image\")\n",
    "\n",
    "grouped_samples = []\n",
    "\n",
    "# Counters for debugging\n",
    "total_samples = 0\n",
    "masks_added = 0\n",
    "masks_skipped = 0\n",
    "\n",
    "def normalize_mask(mask_img):\n",
    "    \"\"\"Normalize mask values to 0-255 range\"\"\"\n",
    "    if mask_img is None or mask_img.size == 0:\n",
    "        return None\n",
    "    \n",
    "    min_val = mask_img.min()\n",
    "    max_val = mask_img.max()\n",
    "    \n",
    "    # If mask is all zeros or constant, return None\n",
    "    if min_val == max_val:\n",
    "        return None\n",
    "    \n",
    "    # Normalize to 0-255\n",
    "    normalized = ((mask_img - min_val) / (max_val - min_val) * 255).astype(np.uint8)\n",
    "    return normalized\n",
    "\n",
    "for sample in ds.iter_samples(progress=True):\n",
    "    total_samples += 1\n",
    "    group = fo.Group()\n",
    "\n",
    "    # Prepare fields to copy (excluding system fields)\n",
    "    fields_to_copy = [\n",
    "        f for f in sample.field_names\n",
    "        if f not in (\"id\", \"filepath\", \"group\", \"tags\", \"metadata\")\n",
    "    ]\n",
    "\n",
    "    # --- Original image slice ---\n",
    "    image_sample = fo.Sample(\n",
    "        filepath=sample.filepath,\n",
    "        group=group.element(\"image\"),\n",
    "    )\n",
    "    for field in fields_to_copy:\n",
    "        image_sample[field] = sample[field]\n",
    "    if sample.metadata is not None:\n",
    "        image_sample.metadata = sample.metadata\n",
    "\n",
    "    # --- Crop slice (if available) ---\n",
    "    crop_sample = None\n",
    "    if sample.query_crop:\n",
    "        crop_sample = fo.Sample(\n",
    "            filepath=sample.query_crop,\n",
    "            group=group.element(\"crop\"),\n",
    "        )\n",
    "        for field in fields_to_copy:\n",
    "            crop_sample[field] = sample[field]\n",
    "\n",
    "    # --- Mask slice (if available and valid) ---\n",
    "    mask_sample = None\n",
    "    if sample.query_mask:\n",
    "        mask_img = cv2.imread(sample.query_mask, cv2.IMREAD_GRAYSCALE)\n",
    "        if mask_img is not None:\n",
    "            # First normalize the mask to 0-255 range\n",
    "            normalized_mask = normalize_mask(mask_img)\n",
    "            \n",
    "            if normalized_mask is not None:\n",
    "                # Now apply threshold on normalized mask\n",
    "                _, bin_mask = cv2.threshold(normalized_mask, 127, 255, cv2.THRESH_BINARY)\n",
    "                \n",
    "                # Check if the mask contains any white pixels\n",
    "                if np.any(bin_mask):\n",
    "                    mask_dir = os.path.join(os.path.dirname(sample.query_mask), \"normalized_binarized_masks\")\n",
    "                    os.makedirs(mask_dir, exist_ok=True)\n",
    "                    mask_filename = os.path.basename(sample.query_mask)\n",
    "                    bin_mask_path = os.path.join(mask_dir, mask_filename)\n",
    "                    cv2.imwrite(bin_mask_path, bin_mask)\n",
    "                    \n",
    "                    mask_sample = fo.Sample(\n",
    "                        filepath=bin_mask_path,\n",
    "                        group=group.element(\"mask\"),\n",
    "                    )\n",
    "                    for field in fields_to_copy:\n",
    "                        mask_sample[field] = sample[field]\n",
    "                    \n",
    "                    masks_added += 1\n",
    "                else:\n",
    "                    masks_skipped += 1\n",
    "                    print(f\"[DEBUG] Skipped mask (all black after threshold): {sample.query_mask}\")\n",
    "            else:\n",
    "                masks_skipped += 1\n",
    "                print(f\"[DEBUG] Skipped mask (constant values): {sample.query_mask}\")\n",
    "        else:\n",
    "            masks_skipped += 1\n",
    "            print(f\"[DEBUG] Could not read mask: {sample.query_mask}\")\n",
    "\n",
    "    # Add all available slices to the group\n",
    "    for s in (image_sample, crop_sample, mask_sample):\n",
    "        if s is not None:\n",
    "            grouped_samples.append(s)\n",
    "\n",
    "# Add all grouped samples to the new dataset\n",
    "grouped_dataset.add_samples(grouped_samples)\n",
    "\n",
    "# Print summary\n",
    "print(f\"\\n[INFO] Processing complete:\")\n",
    "print(f\"  Total samples processed: {total_samples}\")\n",
    "print(f\"  Masks successfully added: {masks_added}\")\n",
    "print(f\"  Masks skipped: {masks_skipped}\")\n",
    "print(f\"  Total grouped samples created: {len(grouped_samples)}\")\n",
    "\n",
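    "# Verify the grouped dataset\n",
    "print(\"\\n[INFO] Grouped dataset info:\")\n",
    "# len() of a grouped dataset counts only the samples in the active (default) slice\n",
    "print(f\"  Samples in default slice: {len(grouped_dataset)}\")\n",
    "print(f\"  Group slices: {grouped_dataset.group_slices}\")\n",
    "\n",
    "# Count mask samples by flattening the 'mask' slice; match() on a grouped\n",
    "# dataset only sees the active slice, so it would not find them\n",
    "mask_count = len(grouped_dataset.select_group_slices(\"mask\"))\n",
    "print(f\"  Mask samples in dataset: {mask_count}\")"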
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02187a0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped_dataset.persistent = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9210aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "session2 = fo.launch_app(grouped_dataset, port=5152, auto=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93434e17",
   "metadata": {},
   "source": [
    "### Computing Embeddings and Similarity Index for Crop Slices in FiftyOne\n",
    "\n",
    "This code demonstrates how to compute embeddings for the \"crop\" slices in a grouped dataset and then create a similarity index for these crops using the CLIP model from the FiftyOne Model Zoo.\n",
    "\n",
    "**Workflow:**\n",
    "\n",
    "1. **Select Crop Slices:**  \n",
    "   Use `select_group_slices(\"crop\")` to create a flattened view containing only the crop samples from your grouped dataset.\n",
    "\n",
    "2. **Load CLIP Model:**  \n",
    "   Load the `\"clip-vit-base32-torch\"` model from the FiftyOne Model Zoo, which supports generating embeddings for images and patches [see: Model Zoo API Reference](https://docs.voxel51.com/model_zoo/api.html#generating-embeddings-with-zoo-models).\n",
    "\n",
    "3. **Compute Embeddings:**  \n",
    "   Compute embeddings for all crop samples and store them in the `\"crop_embeddings\"` field. This can be done using the `compute_embeddings()` method, which works with any model that exposes embeddings [see: Model Zoo API Reference](https://docs.voxel51.com/model_zoo/api.html#generating-embeddings-with-zoo-models).\n",
    "\n",
    "4. **Create Similarity Index:**  \n",
    "   Use `compute_similarity()` to create a similarity index over the crop samples, specifying the model and the field containing the precomputed embeddings. This enables similarity search and sorting by similarity for the crop slices [see: Creating an index](https://docs.voxel51.com/brain.html#creating-an-index)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "762a793d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import fiftyone.zoo as foz\n",
    "import fiftyone.brain as fob\n",
    "\n",
    "# Get all crop samples as a flattened view\n",
    "crop_view = grouped_dataset.select_group_slices(\"crop\")\n",
    "\n",
    "print(f\"[INFO] Crop samples: {len(crop_view)}\")\n",
    "\n",
    "# Compute embeddings on crop samples\n",
    "model = foz.load_zoo_model(\"clip-vit-base32-torch\")\n",
    "crop_view.compute_embeddings(model, embeddings_field=\"crop_embeddings\")\n",
    "\n",
    "# Create similarity index for crops\n",
    "fob.compute_similarity(\n",
    "    crop_view,\n",
    "    model=\"clip-vit-base32-torch\",\n",
    "    embeddings=\"crop_embeddings\",\n",
    "    brain_key=\"crop_sim\",\n",
    ")\n",
    "\n",
    "print(\"[INFO] Crop embeddings computed successfully!\")"
   ]
  },
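  {
   "cell_type": "markdown",
   "id": "5e1c7a9b",
   "metadata": {},
   "source": [
    "Once the `crop_sim` index exists, the crops can be ranked against a sample ID or, since a CLIP model was attached to the index, against a text prompt, e.g. `crop_view.sort_by_similarity(\"dented box\", k=25, brain_key=\"crop_sim\")`.\n",
    "\n",
    "The next cell is a minimal NumPy sketch of the idea behind that ranking, not FiftyOne's implementation. It assumes a cosine metric and uses hypothetical toy vectors in place of the 512-D CLIP crop embeddings: once every vector is scaled to unit length, a plain dot product equals cosine similarity, so ranking reduces to one matrix-vector product followed by a sort."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d2f4b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Minimal sketch, not FiftyOne's implementation: rank stored embeddings by\n",
    "# cosine similarity to a query embedding (a cosine metric is assumed)\n",
    "def top_k_similar(embeddings, query, k=3):\n",
    "    # L2-normalize rows so that a dot product equals cosine similarity\n",
    "    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)\n",
    "    q = query / np.linalg.norm(query)\n",
    "    scores = emb @ q\n",
    "    # Indices of the k highest scores, most similar first\n",
    "    order = np.argsort(-scores)[:k]\n",
    "    return order, scores[order]\n",
    "\n",
    "# Hypothetical toy 4-D vectors standing in for 512-D CLIP crop embeddings\n",
    "toy_embeddings = np.array([\n",
    "    [1.0, 0.0, 0.0, 0.0],\n",
    "    [0.9, 0.1, 0.0, 0.0],\n",
    "    [0.0, 1.0, 0.0, 0.0],\n",
    "    [0.0, 0.0, 1.0, 0.0],\n",
    "])\n",
    "\n",
    "# Use the first vector as the query; it should rank itself first\n",
    "idx, sims = top_k_similar(toy_embeddings, toy_embeddings[0], k=2)\n",
    "print(idx, sims.round(3))"
   ]
  },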
  {
   "cell_type": "markdown",
   "id": "7ac10967",
   "metadata": {},
   "source": [
    "### Visualizing Embeddings for Crop Slices in FiftyOne\n",
    "\n",
    "This code demonstrates how to compute and visualize embeddings for the \"crop\" slices of a grouped dataset using FiftyOne Brain and the CLIP model.\n",
    "\n",
    "**Steps:**\n",
    "\n",
    "1. **Select Crop Slices:**  \n",
    "   Use `select_group_slices(\"crop\")` to obtain a flattened view containing only the crop samples from your grouped dataset.\n",
    "\n",
    "2. **Compute Embeddings and Visualization:**  \n",
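    "   Call `fob.compute_visualization()` on the crop samples, specifying the CLIP model (`\"clip-vit-base32-torch\"`) and a `brain_key` to store the results. This computes an embedding for each crop and projects the embeddings into 2D for interactive exploration in the FiftyOne App's Embeddings panel. Because `crop_embeddings` was already computed in the previous section, you could pass `embeddings=\"crop_embeddings\"` instead of `model` to reuse those vectors rather than running CLIP a second time."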
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95a7cd7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import fiftyone.brain as fob\n",
    "# Get all the crop samples from the dataset\n",
    "flattened_crops = grouped_dataset.select_group_slices(\"crop\")\n",
    "\n",
    "# Compute embeddings and visualization\n",
    "results = fob.compute_visualization(\n",
    "    flattened_crops,  \n",
    "    brain_key=\"crop_embedding_viz\",\n",
    "    model=\"clip-vit-base32-torch\"\n",
    ")"
   ]
  },
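  {
   "cell_type": "markdown",
   "id": "a4c8e2f1",
   "metadata": {},
   "source": [
    "The projection step can be illustrated with a minimal NumPy sketch that reduces embeddings to 2D with PCA. This is a stand-in under stated assumptions: `compute_visualization()` uses UMAP by default (it also accepts `method=\"pca\"`), and the two clusters of random 512-D vectors are hypothetical data, not Kaputt crops. Centering the data and projecting onto the top two right singular vectors keeps the two directions of largest variance, so well-separated clusters should land on opposite sides of the first axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7d9f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Minimal sketch: project embeddings to 2-D with PCA computed via SVD.\n",
    "# PCA is swapped in for UMAP (FiftyOne's default) because it needs only NumPy.\n",
    "def pca_2d(embeddings):\n",
    "    # Center each dimension so the principal axes pass through the mean\n",
    "    centered = embeddings - embeddings.mean(axis=0)\n",
    "    # Rows of vt are principal directions, sorted by decreasing singular value\n",
    "    _, _, vt = np.linalg.svd(centered, full_matrices=False)\n",
    "    return centered @ vt[:2].T\n",
    "\n",
    "# Hypothetical stand-in for 512-D CLIP crop embeddings: two loose clusters\n",
    "rng = np.random.default_rng(0)\n",
    "cluster_a = rng.normal(loc=0.0, scale=0.1, size=(50, 512))\n",
    "cluster_b = rng.normal(loc=1.0, scale=0.1, size=(50, 512))\n",
    "points = pca_2d(np.vstack([cluster_a, cluster_b]))\n",
    "print(points.shape)"
   ]
  },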
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc4fb3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "session2 = fo.launch_app(flattened_crops, port=5152, auto=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d902931",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
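    "- You loaded the Kaputt dataset into FiftyOne and explored its samples and metadata\n",
    "- You built a grouped dataset that pairs each image with its crop and, where a usable mask exists, a binarized query mask\n",
    "- You computed CLIP embeddings, a similarity index, and a 2D embeddings visualization for the crop slices\n"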
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kaputt_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
